import numpy as np
import itertools
import matplotlib.pyplot as plt
import matplotlib

class stochastic_cliff:
    """Tabular "stochastic cliff" gridworld MDP.

    States are the cells of a ``size x size`` grid, indexed row-major
    (``index = i * size + j``). Actions 0..3 shift the (row, col) position by
    the offsets in ``A_DELTA``. Moves that would leave the grid are clipped to
    the border. With probability ``noise`` the chosen action is replaced by one
    drawn uniformly from all four actions.

    Cells marked in ``WALLS`` act as the cliff: every action taken there yields
    reward -1 and moves the agent to state 0 with probability 1. Every action
    at the goal cell ``(goal_i, goal_j)`` yields reward 10. All other
    state-action pairs yield 0.001.
    """

    def __init__(self, noise=0.5, size=10, gamma=0.9):
        """Set up the grid layout.

        Args:
            noise: probability that the chosen action is replaced by a
                uniformly random one (which may coincide with it).
            size: side length of the square grid.
            gamma: discount factor; stored on the instance, not used inside
                this class.
        """
        self.noise = noise
        self.size = size
        self.num_states = self.size ** 2
        self.num_actions = 4
        self.gamma = gamma
        # Cliff/wall cells: rows 0-5, columns 4-5 (fixed, independent of size).
        self.WALLS = np.zeros((self.size, self.size))
        self.WALLS[0:6, 4:6] = 1
        # Goal in row 0, last column.
        self.goal_i, self.goal_j = (0, self.size - 1)

        # (d_row, d_col) offsets for actions 0..3. plot_dynamics2 draws row 0
        # at the top, so these render as up, right, down, left.
        self.A_DELTA = np.array([
            [-1, 0],
            [0, 1],
            [1, 0],
            [0, -1],
        ])

    def ij_to_index(self, i, j):
        """Return the row-major state index of cell (i, j)."""
        return i * self.size + j

    def index_to_ij(self, index):
        """Return the (row, col) cell of a row-major state index."""
        return (index // self.size, index % self.size)

    def dynamics(self, index, a):
        """Return the state reached by deterministically applying action ``a``.

        Positions are clipped to the grid, so moving into a border leaves the
        agent in place. Stepping into a wall cell is allowed here; the wall
        penalty and reset are encoded in ``generate_env``.
        """
        i, j = self.index_to_ij(index)
        di, dj = self.A_DELTA[a]
        i, j = np.clip((i + di, j + dj), 0, self.size - 1)
        return self.ij_to_index(i, j)

    def plot_dynamics2(self, q_s_given_sa, action, path):
        """Draw the transition distribution of one action as arrows.

        For every non-wall cell, one arrow is drawn per movement direction. Its
        length and colour scale with ``q_s_given_sa[s, action, s']``, where
        ``s'`` is the cell that direction leads to.

        Args:
            q_s_given_sa: transition array indexed ``[state, action, next_state]``
                (e.g. the first result of ``generate_env``).
            action: the action whose transition distribution is drawn.
            path: image overlaid in red over the grid extent; presumably a
                ``size x size`` array — confirm with callers.
        """
        size = self.size
        extent = (0, size, 0, size)
        plt.figure(figsize=(8, 8))
        plt.imshow(self.WALLS, cmap='Greys', extent=extent, alpha=0.5, origin='lower')
        plt.imshow(path, cmap='Reds', extent=extent, alpha=0.5, origin='lower')
        for index, a in itertools.product(range(self.num_states), range(self.num_actions)):
            i, j = self.index_to_ij(index)
            if self.WALLS[i, j]:
                continue
            index2 = self.dynamics(index, a)
            p = q_s_given_sa[index, action, index2]
            di, dj = 0.4 * p * self.A_DELTA[a]
            c = matplotlib.cm.Blues(p)
            # Arrow starts at the cell centre; x is the column, y is the row.
            plt.arrow(j + 0.5, i + 0.5, dj, di, length_includes_head=False, width=0.03,
                      fc=c, ec=c)
        plt.xticks(np.arange(0, size, 1))
        plt.yticks(np.arange(0, size, 1))
        plt.grid()
        plt.plot([self.goal_j + 0.5], [self.goal_i + 0.5], 'g*', ms=50)
        plt.xlim([0, size])
        # Inverted y axis so row 0 is drawn at the top.
        plt.ylim([size, 0])
        plt.show()

    def generate_env(self):
        """Build the tabular transition and reward model.

        Returns:
            p_s_given_sa: array of shape (num_states, num_actions, num_states).
                ``p_s_given_sa[s, a, s']`` is the probability of reaching
                ``s'`` after taking ``a`` in ``s``; each ``[s, a]`` slice sums
                to 1.
            r_sa: array of shape (num_states, num_actions) holding rewards.
        """
        p_s_given_sa = np.zeros([self.num_states, self.num_actions, self.num_states])
        for index, a in itertools.product(range(self.num_states), range(self.num_actions)):
            # With probability `noise`, a uniformly random action is executed...
            for a2 in range(self.num_actions):
                p_s_given_sa[index, a, self.dynamics(index, a2)] += self.noise / self.num_actions
            # ...otherwise the intended action is.
            p_s_given_sa[index, a, self.dynamics(index, a)] += 1 - self.noise

        # Small positive reward everywhere, large reward at the goal.
        r_sa = np.full([self.num_states, self.num_actions], 0.001)
        r_sa[self.ij_to_index(self.goal_i, self.goal_j)] = 10.0

        # Wall (cliff) cells: reward -1 and a deterministic reset to state 0.
        # Applied after the goal reward, so a wall wins if the two overlap.
        # flatnonzero on the row-major mask yields i * size + j, like ij_to_index.
        wall_states = np.flatnonzero(self.WALLS == 1)
        r_sa[wall_states] = -1
        p_s_given_sa[wall_states] = 0
        p_s_given_sa[wall_states, :, 0] = 1
        return p_s_given_sa, r_sa
